"""
Copyright (2023) Bytedance Ltd. and/or its affiliates

Licensed under the Apache License, Version 2.0 (the "License"); 
you may not use this file except in compliance with the License. 
You may obtain a copy of the License at 

    http://www.apache.org/licenses/LICENSE-2.0 

Unless required by applicable law or agreed to in writing, software 
distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and 
limitations under the License.

Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
"""

import torch
import torch.nn.functional as F
from torch import nn

# Large negative bias that the pixel-sampling losses in this file add to the
# sampling logits of void pixels (logits + void_mask * constant). This pushes
# void pixels far down the Gumbel-top-k ranking, so in practice they are drawn
# only when there are fewer than k non-void pixels. A finite value is used
# rather than -inf, presumably to keep the additive masking free of inf/NaN
# arithmetic.
_SOFTMAX_MASKING_CONSTANT = -99999.0

# https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan
def divide_no_nan(x: torch.Tensor, y: torch.Tensor):
    """Element-wise ``x / y`` with every non-finite quotient replaced by zero.

    Modeled on ``tf.math.divide_no_nan``. Note that this version is broader than
    the TF op: any NaN or +/-inf in the quotient becomes 0 (e.g. 0/0, 1/0, or a
    NaN numerator), not only the entries whose denominator is zero.
    """
    quotient = x / y
    return quotient.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)


# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393
def focal_cross_entropy_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    weight: torch.Tensor, # This is for PQ-loss weighting
    focal_loss_alpha: float = 0.75,
    focal_loss_gamma: float = 0.0,
    background_channel_index: int = -1):
    """Weighted focal cross-entropy over per-query class predictions.

    Args:
        pred: B x N x C classification logits.
        gt: B x N integer class ids in [0, C) (one-hot encoded with C classes).
        weight: B x N per-query weights (the PQ-style class-loss weight).
        focal_loss_alpha: weight applied to queries whose target is not the
            background channel; background targets get ``1 - alpha``. A negative
            value disables alpha weighting.
        focal_loss_gamma: focusing exponent; 0.0 yields plain cross-entropy.
        background_channel_index: class channel treated as background.

    Returns:
        Scalar tensor: for each image, the summed weighted loss divided by the
        number of non-zero terms (at least 1), then averaged over the batch.
    """
    logits = pred.transpose(1, 2)  # B x C x N
    one_hot = F.one_hot(gt, num_classes=logits.shape[1]).transpose(1, 2).to(logits)  # B x C x N
    per_query = F.cross_entropy(logits, one_hot, reduction="none")  # B x N

    if focal_loss_gamma != 0.0:
        # Down-weight easy examples by (1 - p_true) ** gamma.
        prob_true = (logits.softmax(dim=1) * one_hot).sum(1)  # B x N
        per_query = (1.0 - prob_true).pow(focal_loss_gamma) * per_query

    if focal_loss_alpha >= 0:
        is_background = one_hot[:, background_channel_index]  # B x N
        alpha_weights = (
            focal_loss_alpha * (1.0 - is_background)
            + (1 - focal_loss_alpha) * is_background)
        per_query = alpha_weights * per_query

    per_query = (per_query * weight).flatten(1)  # B x N
    denom = (per_query != 0.0).to(per_query).sum(-1).clamp(min=1.0)  # B
    return divide_no_nan(per_query.sum(-1), denom).mean()


# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50
def _gumbel_topk_sample(logits: torch.Tensor, k: int):
    """Samples k points from the softmax distribution with Gumbel-Top-k trick."""
    # Note that torch.rand is [0, 1), we need to make it (0, 1) to ensure the log is valid.
    gumbel_noise = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device)
    gumbel_noise = -torch.log(-torch.log(gumbel_noise))
    _, indices = torch.topk(logits + gumbel_noise, k)
    return indices


# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576
def pixelwise_insdis_loss(
    pixel_feature: torch.Tensor,
    gt_mask: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    instance_discrimination_temperature: float,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor
    ):
    """Pixel-wise instance discrimination loss on a sampled subset of pixels.

    Per image, ``sample_k`` pixels are drawn with Gumbel-top-k sampling. The
    sampling logits are log(inverse_gt_mask_area) * sample_temperature, and void
    pixels are additionally pushed down by ``_SOFTMAX_MASKING_CONSTANT``.
    For the sampled pixels, the loss is a cross-entropy between two K x K
    similarities:
      * the predicted feature dot products, divided by
        ``instance_discrimination_temperature``;
      * the ground-truth mask-overlap similarity, normalized (with a floor of 1)
        into a distribution along dim 1.

    Args:
        pixel_feature: B x C x H x W pixel features.
        gt_mask: B x N x H x W ground-truth masks.
        sample_temperature: multiplier on log(inverse_gt_mask_area).
        sample_k: number of pixels sampled per image.
        instance_discrimination_temperature: divisor of the feature similarity.
        pixel_gt_void_mask: B x H x W (flattened to B x HW); non-zero on void pixels.
        inverse_gt_mask_area: B x H x W (flattened to B x HW); its log is taken,
            so entries are expected to be positive.

    Returns:
        Scalar tensor: per image, the summed loss over sampled pixels divided by
        the number of non-zero terms (at least 1), averaged over the batch.
    """
    feats = pixel_feature.flatten(2)  # B x C x HW
    masks = gt_mask.flatten(2)  # B x N x HW
    void = pixel_gt_void_mask.flatten(1)  # B x HW
    inv_area = inverse_gt_mask_area.flatten(1)  # B x HW

    sampling_logits = torch.log(inv_area) * sample_temperature  # B x HW
    sampling_logits = sampling_logits + void.to(sampling_logits) * _SOFTMAX_MASKING_CONSTANT
    picked = _gumbel_topk_sample(sampling_logits, sample_k)  # B x K

    # Ground-truth similarity between sampled pixels, normalized per column.
    gt_sampled = torch.gather(masks, dim=2, index=picked.unsqueeze(1).expand(-1, masks.shape[1], -1))  # B x N x K
    gt_sim = torch.einsum('bnk,bnj->bkj', gt_sampled, gt_sampled)  # B x K x K
    gt_sim = gt_sim / torch.clamp(gt_sim.sum(dim=1, keepdim=True), min=1.0)

    # Predicted similarity between the same sampled pixels.
    pred_sampled = torch.gather(feats, dim=2, index=picked.unsqueeze(1).expand(-1, feats.shape[1], -1))  # B x C x K
    pred_sim = torch.einsum('bck,bcj->bkj', pred_sampled, pred_sampled) / instance_discrimination_temperature  # B x K x K

    loss = F.cross_entropy(pred_sim, gt_sim, reduction="none")  # B x K
    denom = torch.clamp((loss != 0.0).to(loss).sum(-1), min=1.0)  # B
    return divide_no_nan(loss.sum(-1), denom).mean()


def aux_semantic_loss(
    pred_semantic_logits: torch.Tensor,
    ground_truth_semantic: torch.Tensor,
    sample_temperature: float,
    sample_k: int,
    pixel_gt_void_mask: torch.Tensor,
    inverse_gt_mask_area: torch.Tensor,
    num_classes: int):
    """Auxiliary semantic-segmentation cross-entropy, optionally on sampled pixels.

    If the prediction's spatial size differs from the ground truth, the ground
    truth and the void/area maps are subsampled with stride
    (gt_size - 1) // (pred_size - 1). This stride must agree between H and W.

    Args:
        pred_semantic_logits: B x C x h x w semantic logits.
        ground_truth_semantic: B x H x W class ids; ``num_classes`` marks ignored pixels.
        sample_temperature: multiplier on log(inverse_gt_mask_area) for sampling.
        sample_k: pixels sampled per image; 0 disables sampling (dense loss).
        pixel_gt_void_mask: B x H x W; non-zero on void pixels.
        inverse_gt_mask_area: B x H x W positive sampling weights (its log is taken).
        num_classes: ignore label for the cross-entropy.

    Returns:
        Scalar tensor: per image, summed loss divided by the number of non-zero
        terms (at least 1), averaged over the batch.
    """
    gt_hw = ground_truth_semantic.shape[-2:]
    pred_hw = pred_semantic_logits.shape[-2:]
    # The prediction may be lower resolution; subsample the ground truth to match.
    if pred_hw != gt_hw:
        stride = (gt_hw[-1] - 1) // (pred_hw[-1] - 1)
        assert stride == (gt_hw[-2] - 1) // (pred_hw[-2] - 1)
        ground_truth_semantic = ground_truth_semantic[:, ::stride, ::stride]
        pixel_gt_void_mask = pixel_gt_void_mask[:, ::stride, ::stride]
        inverse_gt_mask_area = inverse_gt_mask_area[:, ::stride, ::stride]

    logits = pred_semantic_logits.flatten(2)  # B x C x HW
    labels = ground_truth_semantic.flatten(1)  # B x HW
    void = pixel_gt_void_mask.flatten(1)  # B x HW
    inv_area = inverse_gt_mask_area.flatten(1)  # B x HW

    # With sample_k == 0 the loss is a plain dense cross-entropy over all pixels.
    if sample_k != 0:
        sampling_logits = torch.log(inv_area) * sample_temperature  # B x HW
        sampling_logits = sampling_logits + void.to(sampling_logits) * _SOFTMAX_MASKING_CONSTANT
        picked = _gumbel_topk_sample(sampling_logits, sample_k)  # B x K
        labels = torch.gather(labels, dim=1, index=picked)  # B x K
        logits = torch.gather(logits, dim=2, index=picked.unsqueeze(1).expand(-1, logits.shape[1], -1))  # B x C x K

    # Pixels labelled num_classes are ignored.
    keep = (labels != num_classes)  # B x K
    loss = F.cross_entropy(logits, labels, ignore_index=num_classes, reduction='none')  # B x K
    loss = loss * keep.to(loss)
    denom = torch.clamp((loss != 0.0).to(loss).sum(-1), min=1.0)  # B
    return divide_no_nan(loss.sum(-1), denom).mean()


# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510
def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        pixel_gt_void_mask: torch.Tensor,
        matched_cls_prob: torch.Tensor,
        masking_void_pixel: bool = True
    ):
    """Soft Dice loss between softmax-normalized mask predictions and targets.

    Args:
        inputs: B x N x HW mask logits; softmax is taken across the N masks (dim 1).
        targets: B x N x HW target masks.
        pixel_gt_void_mask: B x HW boolean, True on pixels without ground truth.
        matched_cls_prob: B x N per-mask weights multiplied into the per-mask loss.
        masking_void_pixel: if True, predicted probabilities on void pixels are
            zeroed before computing the overlap.

    Returns:
        Scalar tensor: the weighted per-mask Dice loss summed over masks, scaled
        by 0.75 / N, and averaged over the batch.
    """
    probs = inputs.softmax(1)  # B x N x HW
    if masking_void_pixel:
        # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L111
        probs = probs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0)  # drop void pixels

    smooth = 1.0
    numerator = 2 * (probs * targets).sum(-1) + smooth  # B x N
    denominator = probs.sum(-1) + targets.sum(-1) + smooth  # B x N
    per_mask = 1.0 - divide_no_nan(numerator, denominator)
    # In-place so the result keeps the loss dtype (no promotion to the weight dtype).
    per_mask *= matched_cls_prob

    # kMaX-DeepLab sums over masks and averages over batches; the per-mask mean
    # times 0.75 accounts for the modifier used there, per the original authors:
    # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L559
    # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L402
    num_masks = probs.shape[1]
    return (per_mask.sum(1) * 0.75 / num_masks).mean()


def softmax_ce_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        pixel_gt_void_mask: torch.Tensor,
        masking_void_pixel: bool = True
    ):
    """Per-pixel cross-entropy across the N mask predictions.

    At each pixel, the N mask logits are softmax-normalized (class dim 1). They
    are compared against the target distribution over masks.

    Args:
        inputs: B x N x HW mask logits.
        targets: B x N x HW float targets, used as class probabilities by
            ``F.cross_entropy``.
        pixel_gt_void_mask: B x HW boolean, True on pixels without ground truth.
        masking_void_pixel: if True, the loss on void pixels is zeroed before
            reduction, mirroring the same flag in ``dice_loss``.

    Returns:
        Scalar tensor: per image, the summed loss divided by the number of
        pixels with non-zero loss (at least 1), averaged over the batch.
    """
    loss = F.cross_entropy(inputs, targets, reduction="none") # B x HW
    if masking_void_pixel:
        # Remove void pixels so they count toward neither the sum nor the divisor.
        loss = loss.masked_fill(pixel_gt_void_mask, 0)

    num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
    num_non_zero = torch.clamp(num_non_zero, min=1.0)
    loss_sum_per_sample = loss.sum(-1) # B
    return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1


class SetCriterion(nn.Module):
    """Loss computation for kMaX-DeepLab style mask-transformer outputs.

    The process happens in two steps:
        1) ``matcher`` assigns ground-truth masks to predictions. Alongside
           the indices, it returns per-match ``matched_dice`` and ``matched_cls_prob``.
        2) ``process_gt`` turns that assignment into dense per-query targets and
           PQ-style loss weights, which are supervised by the losses named in
           ``losses`` ('labels', 'masks', 'pixels', 'aux_semantic').
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching,
                 pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096,
                 aux_semantic_temperature=2.0, aux_semantic_sample_k=4096,
                 masking_void_pixel=True):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category.
                It is also the label given to unmatched queries and the ignore
                index of the auxiliary semantic loss.
            matcher: called as ``matcher(outputs, targets)``; must return
                ``(indices, matched_dice, matched_cls_prob)``.
            weight_dict: per-loss weights; stored (and shown in ``__repr__``), not applied here.
            eos_coef: lower bound applied to the matched class probabilities
                (mask-loss weights) and to all class-loss weights.
            losses: list of loss names to apply. See get_loss for the available names.
            share_final_matching: if True, auxiliary outputs reuse the matching and
                processed targets of the final output instead of being re-matched.
            pixel_insdis_temperature: sampling temperature of the pixel instance
                discrimination loss.
            pixel_insdis_sample_k: pixels sampled per image by that loss.
            aux_semantic_temperature: sampling temperature of the auxiliary semantic loss.
            aux_semantic_sample_k: pixels sampled per image by that loss (0 = dense).
            masking_void_pixel: forwarded to the mask losses to ignore void pixels.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.share_final_matching = share_final_matching
        self.pixel_insdis_temperature = pixel_insdis_temperature
        self.pixel_insdis_sample_k = pixel_insdis_sample_k
        self.aux_semantic_temperature = aux_semantic_temperature
        self.aux_semantic_sample_k = aux_semantic_sample_k
        self.masking_void_pixel = masking_void_pixel

    def loss_labels(self, outputs, targets):
        """Classification loss: focal cross-entropy weighted by "pq_loss_class_weight".

        ``targets`` is the dict built by ``process_gt``. Its "labels" entry is B x N
        and holds ``num_classes`` for unmatched queries. Because ``F.one_hot`` needs
        labels below the channel count, "pred_logits" (B x N x C) must therefore
        have at least ``num_classes + 1`` channels.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"] # B x N x C
        target_classes = targets["labels"] # B x N
        pq_loss_class_weight = targets["pq_loss_class_weight"]
        losses = {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)}
        return losses

    def loss_masks(self, outputs, targets):
        """Mask losses: softmax cross-entropy across queries plus a weighted Dice loss.

        Uses "masks" (B x N x H x W), "pixel_gt_void_mask" (B x H x W) and
        "pq_loss_mask_weight" (B x N, the Dice weight) from ``process_gt``.
        """
        src_masks = outputs["pred_masks"] # B x N x H x W
        target_masks = targets["masks"]
        pq_loss_mask_weight = targets["pq_loss_mask_weight"]
        pixel_gt_void_mask = targets["pixel_gt_void_mask"]

        src_masks = src_masks.flatten(2) # B x N x HW
        target_masks = target_masks.flatten(2) # B x N x HW
        pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW

        losses = {
            "loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask, masking_void_pixel=self.masking_void_pixel),
            "loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight, masking_void_pixel=self.masking_void_pixel),
        }

        return losses

    def loss_pixels(self, outputs, targets):
        """Pixel-wise instance discrimination loss on ``outputs["pixel_feature"]``."""
        losses = {"loss_pixel_insdis": pixelwise_insdis_loss(
            pixel_feature=outputs["pixel_feature"],
            gt_mask=targets["masks"],
            sample_temperature=self.pixel_insdis_temperature,
            sample_k=self.pixel_insdis_sample_k,
            instance_discrimination_temperature=0.3,
            pixel_gt_void_mask=targets["pixel_gt_void_mask"],
            inverse_gt_mask_area=targets["inverse_gt_mask_area"]
            )}
        return losses

    def loss_semantic(self, outputs, targets):
        """Auxiliary semantic loss on ``outputs["aux_semantic_pred"]``.

        Requires "ground_truth_semantic", which ``process_gt`` adds only when
        ``process_semantic=True`` and the targets carry "semantic_masks".
        """
        losses = {"loss_aux_semantic": aux_semantic_loss(
            pred_semantic_logits=outputs["aux_semantic_pred"],
            ground_truth_semantic=targets["ground_truth_semantic"],
            sample_temperature=self.aux_semantic_temperature,
            sample_k=self.aux_semantic_sample_k,
            pixel_gt_void_mask=targets["pixel_gt_void_mask"],
            inverse_gt_mask_area=targets["inverse_gt_mask_area"],
            num_classes=self.num_classes
        )}
        return losses

    @torch.no_grad()
    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matched prediction indices into advanced-index tuples.

        Args:
            indices: sequence of (src, tgt) pairs, one per image.

        Returns:
            (batch_idx, src_idx): both of length sum(len(src)). ``batch_idx[i]``
            is the image that the i-th matched prediction ``src_idx[i]`` belongs to.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def get_loss(self, loss, outputs, targets):
        """Dispatch to the loss method registered under the name ``loss``."""
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
            'pixels': self.loss_pixels,
            'aux_semantic': self.loss_semantic,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets)

    @torch.no_grad()
    def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False):
        """Build dense per-query targets and loss weights from matching results.

        Ground truth is permuted and zero-padded so that each of the N predictions
        has a target. As a result, one processed dict can be shared across
        decoder outputs (see ``share_final_matching``).

        Args:
            outputs: dict with "pred_masks" (B x N x H x W) and "pred_logits" (B x N x C).
            targets: list of B dicts with "masks" (M_b x H x W) and "labels" (M_b).
                They may also carry "semantic_masks" (H x W, -1 = ignore).
            indices: per-image (src, tgt) index pairs from the matcher.
            matched_dice: per-image tensors aligned with ``indices``. They become
                the class-loss weights of matched queries.
            matched_cls_prob: per-image tensors aligned with ``indices``. They
                become the mask-loss weights of matched queries.
            process_semantic: also build "ground_truth_semantic" when available.

        Returns:
            dict with "masks", "labels", "pq_loss_mask_weight",
            "pq_loss_class_weight", "pixel_gt_void_mask", "inverse_gt_mask_area",
            and optionally "ground_truth_semantic".
        """
        src_idx = self._get_src_permutation_idx(indices)

        src_masks = outputs["pred_masks"].detach() # B x N x H x W

        # Scatter the matched GT masks into a zero-padded B x N x H x W tensor.
        target_masks = torch.zeros_like(src_masks)
        target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks)
        target_masks[src_idx] = target_masks_o

        # Instance masks may overlap. Pixels whose targets sum to more than one
        # are rescaled so that their targets sum to one across the N queries.
        target_masks = target_masks / torch.clamp(target_masks.sum(1, keepdim=True), min=1.0)

        # Mask-loss weights: the matcher's matched_cls_prob, floored at eos_coef.
        # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034
        # Unmatched queries get weight 0, i.e. no mask penalty.
        matched_cls_prob_o = torch.clamp(torch.cat(list(matched_cls_prob)), min=self.eos_coef)
        mask_loss_weight = torch.zeros(
            src_masks.shape[:2], dtype=src_masks.dtype, device=src_masks.device
        ) # B x N
        mask_loss_weight[src_idx] = matched_cls_prob_o.to(mask_loss_weight)

        # Pixels whose normalized targets sum to less than one are treated as void (unlabelled).
        pixel_gt_void_mask = (target_masks.sum(1) < 1) # B x H x W

        # Per-pixel inverse area of the covering mask(s); used to bias pixel sampling.
        mask_gt_area = target_masks.sum(2).sum(2) # B x N
        pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area) # B x H x W
        inverse_gt_mask_area = (pixel_gt_area.shape[1] * pixel_gt_area.shape[2]) / torch.clamp(pixel_gt_area, min=1.0) # B x H x W

        # Class targets: matched GT labels, with num_classes (no-object) as padding.
        src_logits = outputs["pred_logits"] # B x N x C
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[src_idx] = target_classes_o

        def void_overlap_ratio(x, y):
            # x: B x N x H x W mask probabilities; y: B x H x W void indicator.
            # Returns, per query, the fraction of its predicted mask mass on void (B x N).
            x = x.flatten(2) # B x N x L
            y = y.flatten(1) # B x L
            intersection = torch.einsum('bnl,bl->bn', x, y) # B x N
            denominator = x.sum(-1) # B x N
            return intersection / (denominator + 1e-5) # B x N

        # Class-loss weights: the matcher's dice for matched queries. Unmatched
        # queries use the fraction of their predicted mask that lies on void pixels.
        src_masks_prob = src_masks.softmax(1)
        void_mask = pixel_gt_void_mask.to(src_masks_prob) # B x H x W
        matched_dice_o = torch.cat(list(matched_dice))
        class_loss_weight = void_overlap_ratio(src_masks_prob, void_mask)
        class_loss_weight[src_idx] = matched_dice_o.to(class_loss_weight)
        class_loss_weight = torch.clamp(class_loss_weight, min=self.eos_coef)

        processed_gt = {"masks": target_masks, "labels": target_classes,
            "pq_loss_mask_weight": mask_loss_weight,
            "pq_loss_class_weight": class_loss_weight,
            "pixel_gt_void_mask": pixel_gt_void_mask,
            "inverse_gt_mask_area": inverse_gt_mask_area,}

        if process_semantic and "semantic_masks" in targets[0]:
            ground_truth_semantic = torch.stack([t["semantic_masks"] for t in targets], dim=0) # B x H x W
            # Remap the -1 ignore label to num_classes (the semantic loss ignore index).
            ground_truth_semantic[ground_truth_semantic==-1] = self.num_classes
            processed_gt.update({"ground_truth_semantic": ground_truth_semantic})

        return processed_gt

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format.
                      An optional "aux_outputs" list holds intermediate predictions.
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depend on the losses applied; see process_gt.
        Returns:
             dict of unweighted losses; auxiliary losses get an ``_{i}`` suffix.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
        indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets)
        # Pad GT to the same number of predictions.
        processed_targets = self.process_gt(outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True)
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, processed_targets))

        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                if not self.share_final_matching:
                    # Re-match this intermediate prediction and rebuild its targets;
                    # otherwise the final output's matching and targets are reused.
                    indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets)
                    processed_targets = self.process_gt(aux_outputs, targets, indices, matched_dice, matched_cls_prob)
                for loss in self.losses:
                    if loss == 'aux_semantic':
                        # Only for the final output.
                        continue
                    l_dict = self.get_loss(loss, aux_outputs, processed_targets)
                    losses.update({k + f"_{i}": v for k, v in l_dict.items()})
        return losses

    def __repr__(self):
        head = "Criterion " + self.__class__.__name__
        body = [
            "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
            "losses: {}".format(self.losses),
            "weight_dict: {}".format(self.weight_dict),
            "num_classes: {}".format(self.num_classes),
            "eos_coef: {}".format(self.eos_coef),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)